{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os \n",
    "SPAM_PATH = os.path.join( \"spam\")\n",
    "HAM_DIR = os.path.join(SPAM_PATH, \"easy_ham\")\n",
    "SPAM_DIR = os.path.join(SPAM_PATH, \"spam\")\n",
    "\n",
    "ham_files = [HAM_DIR + '/'+name for name in sorted(os.listdir(HAM_DIR)) if len(name) > 20]\n",
    "spam_files = [SPAM_DIR + '/'+name for name in sorted(os.listdir(SPAM_DIR)) if len(name) > 20]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import email\n",
    "import email.policy\n",
    "def load_email(fn):\n",
    "    with open(fn, \"rb\") as f:\n",
    "        return email.parser.BytesParser(policy = email.policy.default).parse(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "ham_emails = [load_email(fn) for fn in ham_files]\n",
    "spam_emails = [load_email(fn) for fn in spam_files]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "'Date:        Wed, 21 Aug 2002 10:54:46 -0500\\n    From:        Chris Garrigues <cwg-dated-1030377287.06fa6d@DeepEddy.Com>\\n    Message-ID:  <1029945287.4797.TMDA@deepeddy.vircio.com>\\n\\n\\n  | I can\\'t reproduce this error.\\n\\nFor me it is very repeatable... (like every time, without fail).\\n\\nThis is the debug log of the pick happening ...\\n\\n18:19:03 Pick_It {exec pick +inbox -list -lbrace -lbrace -subject ftp -rbrace -rbrace} {4852-4852 -sequence mercury}\\n18:19:03 exec pick +inbox -list -lbrace -lbrace -subject ftp -rbrace -rbrace 4852-4852 -sequence mercury\\n18:19:04 Ftoc_PickMsgs {{1 hit}}\\n18:19:04 Marking 1 hits\\n18:19:04 tkerror: syntax error in expression \"int ...\\n\\nNote, if I run the pick command by hand ...\\n\\ndelta$ pick +inbox -list -lbrace -lbrace -subject ftp -rbrace -rbrace  4852-4852 -sequence mercury\\n1 hit\\n\\nThat\\'s where the \"1 hit\" comes from (obviously).  The version of nmh I\\'m\\nusing is ...\\n\\ndelta$ pick -version\\npick -- nmh-1.0.4 [compiled on fuchsia.cs.mu.OZ.AU at Sun Mar 17 14:55:56 ICT 2002]\\n\\nAnd the relevant part of my .mh_profile ...\\n\\ndelta$ mhparam pick\\n-seq sel -list\\n\\n\\nSince the pick command works, the sequence (actually, both of them, the\\none that\\'s explicit on the command line, from the search popup, and the\\none that comes from .mh_profile) do get created.\\n\\nkre\\n\\nps: this is still using the version of the code form a day ago, I haven\\'t\\nbeen able to reach the cvs repository today (local routing issue I think).\\n\\n\\n\\n_______________________________________________\\nExmh-workers mailing list\\nExmh-workers@redhat.com\\nhttps://listman.redhat.com/mailman/listinfo/exmh-workers'"
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ham_emails[0].get_content().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "'Help wanted.  We are a 14 year old fortune 500 company, that is\\ngrowing at a tremendous rate.  We are looking for individuals who\\nwant to work from home.\\n\\nThis is an opportunity to make an excellent income.  No experience\\nis required.  We will train you.\\n\\nSo if you are looking to be employed from home with a career that has\\nvast opportunities, then go:\\n\\nhttp://www.basetel.com/wealthnow\\n\\nWe are looking for energetic and self motivated people.  If that is you\\nthan click on the link and fill out the form, and one of our\\nemployement specialist will contact you.\\n\\nTo be removed from our link simple go to:\\n\\nhttp://www.basetel.com/remove.html\\n\\n\\n4139vOLW7-758DoDY1425FRhM1-764SMFc8513fCsLl40'"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "spam_emails[6].get_content().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_email_structure(email):\n",
    "    if isinstance(email, str):\n",
    "        return email\n",
    "    payload = email.get_payload()\n",
    "    if isinstance(payload, list):\n",
    "        return \"multipart({})\".format(\", \".join([\n",
    "            get_email_structure(sub_email)\n",
    "            for sub_email in payload\n",
    "        ]))\n",
    "    else:\n",
    "        return email.get_content_type()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "Counter({1: 1, 4: 1, 3: 3, 2: 2})"
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from collections import Counter\n",
    "Counter([1,4,3,3,2,2,3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "'text/html'"
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "get_email_structure(spam_emails[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def structure_counter(emails):\n",
    "    structure = Counter()\n",
    "    for email in emails:\n",
    "        structure[get_email_structure(email)] += 1\n",
    "    return structure"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "[('text/plain', 218),\n ('text/html', 183),\n ('multipart(text/plain, text/html)', 45),\n ('multipart(text/html)', 20),\n ('multipart(text/plain)', 19),\n ('multipart(multipart(text/html))', 5),\n ('multipart(text/plain, image/jpeg)', 3),\n ('multipart(text/html, application/octet-stream)', 2),\n ('multipart(text/plain, application/octet-stream)', 1),\n ('multipart(text/html, text/plain)', 1),\n ('multipart(multipart(text/html), application/octet-stream, image/jpeg)', 1),\n ('multipart(multipart(text/plain, text/html), image/gif)', 1),\n ('multipart/alternative', 1)]"
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "structure_counter(spam_emails).most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "[('text/plain', 2408),\n ('multipart(text/plain, application/pgp-signature)', 66),\n ('multipart(text/plain, text/html)', 8),\n ('multipart(text/plain, text/plain)', 4),\n ('multipart(text/plain)', 3),\n ('multipart(text/plain, application/octet-stream)', 2),\n ('multipart(text/plain, text/enriched)', 1),\n ('multipart(text/plain, application/ms-tnef, text/plain)', 1),\n ('multipart(multipart(text/plain, text/plain, text/plain), application/pgp-signature)',\n  1),\n ('multipart(text/plain, video/mng)', 1),\n ('multipart(text/plain, multipart(text/plain))', 1),\n ('multipart(text/plain, application/x-pkcs7-signature)', 1),\n ('multipart(text/plain, multipart(text/plain, text/plain), text/rfc822-headers)',\n  1),\n ('multipart(text/plain, multipart(text/plain, text/plain), multipart(multipart(text/plain, application/x-pkcs7-signature)))',\n  1),\n ('multipart(text/plain, application/x-java-applet)', 1)]"
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "structure_counter(ham_emails).most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "X = np.array(ham_emails  + spam_emails)\n",
    "y = np.hstack((np.zeros(len(ham_emails)) , np.ones(len(spam_emails)))) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train, X_test, y_train, y_test = train_test_split(X,y, test_size = 0.2,stratify = y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re \n",
    "from html import unescape\n",
    "\n",
    "def html_to_plain_text(html):\n",
    "    text = re.sub(r'<head.*?>.*?</head.*?>' , '' , html, flags=re.M | re.S | re.I)\n",
    "    text = re.sub(r'<a\\s.*?>',' HYPERLINK ', text, flags=re.M| re.S| re.I)\n",
    "    text = re.sub(r'<.*?>','', text, flags=re.M| re.S)\n",
    "    text = re.sub(r'(\\s*\\n)+', '\\n', text, flags = re.M| re.S)\n",
    "    return unescape(text)"
   ]
  },
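  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check of `html_to_plain_text` (a sketch, assuming `spam_emails[0]` is still the `text/html` message identified above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the HTML-only spam message seen earlier and show an excerpt.\n",
    "sample_html = spam_emails[0].get_content()\n",
    "print(html_to_plain_text(sample_html)[:500])"
   ]
  },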
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "def email_to_text(email):\n",
    "    html = None\n",
    "    for part in email.walk():\n",
    "        ctype = part.get_content_type()\n",
    "        if not ctype in (\"text/plain\", \"text/html\"):\n",
    "            continue\n",
    "        try:\n",
    "            content = part.get_content()\n",
    "        except: # 解决编码问题\n",
    "            content = str(part.get_payload())\n",
    "        if ctype == \"text/plain\":\n",
    "            return content\n",
    "        else:\n",
    "            html = content\n",
    "    if html:\n",
    "        return html_to_plain_text(html)"
   ]
  },
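  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`email_to_text` should now handle plain-text, HTML, and multipart messages alike; a minimal check on one ham and one spam message:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both calls should return plain text regardless of the source format.\n",
    "print(email_to_text(ham_emails[0])[:200])\n",
    "print(email_to_text(spam_emails[0])[:200])"
   ]
  },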
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import nltk \n",
    "from urlextract import URLExtract"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.base import BaseEstimator, TransformerMixin\n",
    "class EmailToWordCounterTransformer(BaseEstimator, TransformerMixin):\n",
    "    def __init__(self, strip_headers=True, lower_case=True, remove_punctuation=True,\n",
    "                 replace_urls=True, replace_numbers=True,  stemmer=None):\n",
    "        self.strip_headers = strip_headers\n",
    "        self.lower_case = lower_case\n",
    "        self.remove_punctuation = remove_punctuation\n",
    "        self.replace_urls = replace_urls\n",
    "        self.replace_numbers = replace_numbers\n",
    "        self.stemmer = stemmer\n",
    "    def fit(self, X, y=None):\n",
    "        return self\n",
    "    def transform(self, X, y=None):\n",
    "        X_transformed = []\n",
    "        for email in X:\n",
    "            text = email_to_text(email) or \"\"\n",
    "            if self.lower_case:\n",
    "                text = text.lower()\n",
    "            if self.replace_urls:\n",
    "                extractor = URLExtract()\n",
    "                urls = list(set(extractor.find_urls(text)))\n",
    "                urls.sort(key=lambda url: len(url), reverse=True)\n",
    "                for url in urls:  # 替换url 为 ‘URL’\n",
    "                    text = text.replace(url, \" URL \")\n",
    "            if self.replace_numbers:  # 替换数字\n",
    "                text = re.sub(r'\\d+(?:\\.\\d*(?:[eE]\\d+))?', 'NUMBER', text)\n",
    "            if self.remove_punctuation:  # 删除标点符号\n",
    "                text = re.sub(r'\\W+', ' ', text, flags=re.M)\n",
    "            word_counts = Counter(text.split())\n",
    "            if self.stemmer:\n",
    "                stemmed_word_counts = Counter()\n",
    "                for word, count in word_counts.items():\n",
    "                    stemmed_word = self.stemmer.stem(word)\n",
    "                    stemmed_word_counts[stemmed_word] += count\n",
    "                word_counts = stemmed_word_counts\n",
    "            X_transformed.append(word_counts)\n",
    "        return np.array(X_transformed)"
   ]
  },
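  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see what the transformer produces, run it on a few training emails (a sketch; passing `stemmer=nltk.PorterStemmer()` would additionally merge words that share a stem):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each email becomes a Counter mapping words to their counts.\n",
    "X_few = X_train[:3]\n",
    "X_few_wordcounts = EmailToWordCounterTransformer().fit_transform(X_few)\n",
    "X_few_wordcounts"
   ]
  },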
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.sparse import csr_matrix\n",
    "\n",
    "class WordCounterToVectorTransformer(BaseEstimator, TransformerMixin):\n",
    "    def __init__(self, vocabulary_size = 1000):\n",
    "        self.vocabulary_size = vocabulary_size  # 词汇量\n",
    "    def fit(self, X, y = None):\n",
    "        total_count = Counter()\n",
    "        for word_count in X:\n",
    "            for word, count in word_count.items():\n",
    "                total_count[word] += min(count, 10)\n",
    "        most_common = total_count.most_common()[:self.vocabulary_size]\n",
    "        self.most_common_ = most_common\n",
    "        self.vocabulary_ = {word: index + 1 for index, (word, count) in enumerate(most_common)}\n",
    "        return self\n",
    "    def transform(self, X, y = None):\n",
    "        rows = []\n",
    "        cols = []\n",
    "        data = []\n",
    "        for row, word_count in enumerate(X):\n",
    "            for word, count in word_count.items():\n",
    "                rows.append(row) # 训练集 实例个数\n",
    "                cols.append(self.vocabulary_.get(word, 0)) # 取得单词在词汇表中的索引位置，0代表未出现在词汇表中\n",
    "                data.append(count)\n",
    "        return csr_matrix((data, (rows, cols)), shape=(len(X), self.vocabulary_size + 1)) # 输出稀疏矩阵 +1因为第一列要显示未出现在词汇表中的单词统计数"
   ]
  },
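  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small vocabulary makes the resulting matrix easy to inspect (a sketch, reusing `X_few_wordcounts` from the demo above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# With vocabulary_size=10 the matrix has 11 columns: column 0 holds\n",
    "# out-of-vocabulary counts, columns 1-10 the ten most common words.\n",
    "vocab_transformer = WordCounterToVectorTransformer(vocabulary_size=10)\n",
    "X_few_vectors = vocab_transformer.fit_transform(X_few_wordcounts)\n",
    "print(X_few_vectors.toarray())\n",
    "vocab_transformer.vocabulary_"
   ]
  },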
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.pipeline import Pipeline\n",
    "preprocess_pipeline = Pipeline([\n",
    "    (\"email_to_wordcount\", EmailToWordCounterTransformer()),\n",
    "    (\"wordcount_to_vector\", WordCounterToVectorTransformer()),\n",
    "])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train_transformed = preprocess_pipeline.fit_transform(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": "[Parallel(n_jobs=-1)]: Using backend LokyBackend with 8 concurrent workers.\n[Parallel(n_jobs=-1)]: Done   3 out of   3 | elapsed:    1.6s finished\n"
    },
    {
     "data": {
      "text/plain": "0.9870833333333334"
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.model_selection import cross_val_score\n",
    "\n",
    "log_clf = LogisticRegression(solver=\"liblinear\", random_state=42) # 采用逻辑回归分类器\n",
    "score = cross_val_score(log_clf, X_train_transformed, y_train, cv=3, n_jobs= -1, verbose=3)\n",
    "score.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": "/home/tecty/.local/lib/python3.8/site-packages/sklearn/svm/_base.py:946: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.\n  warnings.warn(\"Liblinear failed to converge, increase \"\n"
    }
   ],
   "source": [
    "from sklearn.metrics import precision_score, recall_score\n",
    "\n",
    "X_test_transformed = preprocess_pipeline.transform(X_test)\n",
    "\n",
    "log_clf = LogisticRegression(solver=\"liblinear\", random_state=42)\n",
    "log_clf.fit(X_train_transformed, y_train)\n",
    "\n",
    "y_pred = log_clf.predict(X_test_transformed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "(0.989247311827957, 0.92)"
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "( precision_score(y_test, y_pred),recall_score(y_test, y_pred))"
   ]
  },
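  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a fuller picture of the test-set errors, a confusion matrix (rows = true class, columns = predicted class):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import confusion_matrix\n",
    "\n",
    "# [[true ham, ham flagged as spam], [missed spam, true spam]]\n",
    "confusion_matrix(y_test, y_pred)"
   ]
  },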
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "array([0.53761964, 0.46262407, 0.54608947, 0.52821462])"
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.neighbors import KNeighborsRegressor\n",
    "\n",
    "knn_clf = KNeighborsRegressor()\n",
    "cross_val_score(knn_clf, X_train_transformed, y_train, cv=4 , n_jobs= -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "array([0.985     , 0.98333333, 0.98      , 0.98666667])"
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "forest_clf = RandomForestClassifier()\n",
    "cross_val_score(forest_clf, X_train_transformed, y_train, cv=4 , n_jobs= -1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.1-final"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}